{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.\n",
    "- Author: Sebastian Raschka\n",
    "- GitHub Repository: https://github.com/rasbt/deeplearning-models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sebastian Raschka \n",
      "\n",
      "CPython 3.6.8\n",
      "IPython 7.2.0\n",
      "\n",
      "torch 1.0.0\n"
     ]
    }
   ],
   "source": [
    "%load_ext watermark\n",
    "%watermark -a 'Sebastian Raschka' -v -p torch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- Runs on CPU or GPU (if available)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Model Zoo -- Multilayer Perceptron with Sequential Wrapper"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "import numpy as np\n",
    "from torchvision import datasets\n",
    "from torchvision import transforms\n",
    "from torch.utils.data import DataLoader\n",
    "import torch.nn.functional as F\n",
    "import torch\n",
    "\n",
    "\n",
    "# When CUDA is available, turn on cuDNN's `deterministic` flag\n",
    "# (intended to make GPU runs repeatable); nothing is changed on CPU-only machines.\n",
    "if torch.cuda.is_available():\n",
    "    torch.backends.cudnn.deterministic = True"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Settings and Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Image batch dimensions: torch.Size([64, 1, 28, 28])\n",
      "Image label dimensions: torch.Size([64])\n"
     ]
    }
   ],
   "source": [
    "##########################\n",
    "### SETTINGS\n",
    "##########################\n",
    "\n",
    "# Device\n",
    "# Prefer the GPU with index `preferred_gpu`, but fall back to GPU 0 when the\n",
    "# machine has fewer GPUs (a fixed 'cuda:3' is an invalid device there) and to\n",
    "# the CPU when CUDA is not available at all.\n",
    "preferred_gpu = 3\n",
    "if torch.cuda.is_available():\n",
    "    gpu_index = preferred_gpu if preferred_gpu < torch.cuda.device_count() else 0\n",
    "    device = torch.device('cuda:%d' % gpu_index)\n",
    "else:\n",
    "    device = torch.device('cpu')\n",
    "\n",
    "# Hyperparameters\n",
    "random_seed = 1      # passed to torch.manual_seed before the model is built\n",
    "learning_rate = 0.1  # learning rate given to the SGD optimizer\n",
    "num_epochs = 10      # number of passes over the training loader\n",
    "batch_size = 64      # batch size of both data loaders\n",
    "\n",
    "# Architecture\n",
    "num_features = 784   # length of a flattened 28*28 input image\n",
    "num_hidden_1 = 128   # width of the first hidden layer\n",
    "num_hidden_2 = 256   # width of the second hidden layer\n",
    "num_classes = 10     # number of output units of the network\n",
    "\n",
    "\n",
    "##########################\n",
    "### MNIST DATASET\n",
    "##########################\n",
    "\n",
    "# transforms.ToTensor() turns each image into a tensor and scales\n",
    "# the input pixels to the 0-1 range.\n",
    "to_tensor = transforms.ToTensor()\n",
    "\n",
    "# Training split; download=True is requested for this split.\n",
    "train_dataset = datasets.MNIST(root='data', train=True,\n",
    "                               transform=to_tensor, download=True)\n",
    "\n",
    "# Test split, read from the same root directory.\n",
    "test_dataset = datasets.MNIST(root='data', train=False,\n",
    "                              transform=to_tensor)\n",
    "\n",
    "# Mini-batch loaders: the training loader shuffles, the test loader does not.\n",
    "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
    "test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)\n",
    "\n",
    "# Checking the dataset: print the shapes of the first training batch only\n",
    "for images, labels in train_loader:\n",
    "    print('Image batch dimensions:', images.shape)\n",
    "    print('Image label dimensions:', labels.shape)\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "##########################\n",
    "### MODEL\n",
    "##########################\n",
    "\n",
    "class MultilayerPerceptron(torch.nn.Module):\n",
    "    \"\"\"Fully connected classifier assembled with torch.nn.Sequential.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    num_features : int\n",
    "        Size of each input vector (in_features of the first Linear layer).\n",
    "    num_classes : int\n",
    "        Number of output units (out_features of the last Linear layer).\n",
    "    hidden_sizes : sequence of int, optional\n",
    "        Width of each hidden layer, in order. When omitted, the module-level\n",
    "        settings ``num_hidden_1`` and ``num_hidden_2`` are used, which gives\n",
    "        the two-hidden-layer network Linear-ReLU-Linear-ReLU-Linear.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_features, num_classes, hidden_sizes=None):\n",
    "        super(MultilayerPerceptron, self).__init__()\n",
    "\n",
    "        if hidden_sizes is None:\n",
    "            # Default: read the widths from the notebook's settings cell.\n",
    "            hidden_sizes = (num_hidden_1, num_hidden_2)\n",
    "\n",
    "        # One Linear + in-place ReLU pair per hidden layer, followed by the\n",
    "        # final Linear layer that produces the logits.\n",
    "        layers = []\n",
    "        in_size = num_features\n",
    "        for out_size in hidden_sizes:\n",
    "            layers.append(torch.nn.Linear(in_size, out_size))\n",
    "            layers.append(torch.nn.ReLU(inplace=True))\n",
    "            in_size = out_size\n",
    "        layers.append(torch.nn.Linear(in_size, num_classes))\n",
    "        self.net = torch.nn.Sequential(*layers)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Run the network on a batch and return ``(logits, log_probas)``.\n",
    "\n",
    "        ``x`` is fed to the Sequential stack unchanged, so callers must pass\n",
    "        already-flattened inputs of shape ``(batch, num_features)``.\n",
    "        The second return value is ``log_softmax(logits, dim=1)``, i.e.\n",
    "        log-probabilities rather than probabilities; its argmax along dim 1\n",
    "        is the same as that of the logits.\n",
    "        \"\"\"\n",
    "        logits = self.net(x)\n",
    "        log_probas = F.log_softmax(logits, dim=1)\n",
    "        return logits, log_probas\n",
    "\n",
    "    \n",
    "# Seed PyTorch's random number generator before the layers are created.\n",
    "torch.manual_seed(random_seed)\n",
    "\n",
    "# Build the network and move it to the selected device in one expression.\n",
    "model = MultilayerPerceptron(num_features=num_features,\n",
    "                             num_classes=num_classes).to(device)\n",
    "\n",
    "# Plain SGD over all parameters of the model.\n",
    "optimizer = torch.optim.SGD(params=model.parameters(),\n",
    "                            lr=learning_rate)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 001/010 | Batch 000/938 | Cost: 2.3075\n",
      "Epoch: 001/010 | Batch 050/938 | Cost: 1.7222\n",
      "Epoch: 001/010 | Batch 100/938 | Cost: 0.7172\n",
      "Epoch: 001/010 | Batch 150/938 | Cost: 0.5022\n",
      "Epoch: 001/010 | Batch 200/938 | Cost: 0.3913\n",
      "Epoch: 001/010 | Batch 250/938 | Cost: 0.4786\n",
      "Epoch: 001/010 | Batch 300/938 | Cost: 0.4177\n",
      "Epoch: 001/010 | Batch 350/938 | Cost: 0.1392\n",
      "Epoch: 001/010 | Batch 400/938 | Cost: 0.2752\n",
      "Epoch: 001/010 | Batch 450/938 | Cost: 0.2991\n",
      "Epoch: 001/010 | Batch 500/938 | Cost: 0.3828\n",
      "Epoch: 001/010 | Batch 550/938 | Cost: 0.2604\n",
      "Epoch: 001/010 | Batch 600/938 | Cost: 0.2135\n",
      "Epoch: 001/010 | Batch 650/938 | Cost: 0.3743\n",
      "Epoch: 001/010 | Batch 700/938 | Cost: 0.1834\n",
      "Epoch: 001/010 | Batch 750/938 | Cost: 0.2983\n",
      "Epoch: 001/010 | Batch 800/938 | Cost: 0.1182\n",
      "Epoch: 001/010 | Batch 850/938 | Cost: 0.1066\n",
      "Epoch: 001/010 | Batch 900/938 | Cost: 0.3104\n",
      "Epoch: 001/010 training accuracy: 92.76%\n",
      "Time elapsed: 0.21 min\n",
      "Epoch: 002/010 | Batch 000/938 | Cost: 0.2066\n",
      "Epoch: 002/010 | Batch 050/938 | Cost: 0.1977\n",
      "Epoch: 002/010 | Batch 100/938 | Cost: 0.1766\n",
      "Epoch: 002/010 | Batch 150/938 | Cost: 0.3247\n",
      "Epoch: 002/010 | Batch 200/938 | Cost: 0.2914\n",
      "Epoch: 002/010 | Batch 250/938 | Cost: 0.3427\n",
      "Epoch: 002/010 | Batch 300/938 | Cost: 0.0698\n",
      "Epoch: 002/010 | Batch 350/938 | Cost: 0.1212\n",
      "Epoch: 002/010 | Batch 400/938 | Cost: 0.1271\n",
      "Epoch: 002/010 | Batch 450/938 | Cost: 0.1743\n",
      "Epoch: 002/010 | Batch 500/938 | Cost: 0.0741\n",
      "Epoch: 002/010 | Batch 550/938 | Cost: 0.2402\n",
      "Epoch: 002/010 | Batch 600/938 | Cost: 0.2013\n",
      "Epoch: 002/010 | Batch 650/938 | Cost: 0.1400\n",
      "Epoch: 002/010 | Batch 700/938 | Cost: 0.1400\n",
      "Epoch: 002/010 | Batch 750/938 | Cost: 0.4599\n",
      "Epoch: 002/010 | Batch 800/938 | Cost: 0.2720\n",
      "Epoch: 002/010 | Batch 850/938 | Cost: 0.1009\n",
      "Epoch: 002/010 | Batch 900/938 | Cost: 0.2104\n",
      "Epoch: 002/010 training accuracy: 95.64%\n",
      "Time elapsed: 0.42 min\n",
      "Epoch: 003/010 | Batch 000/938 | Cost: 0.1309\n",
      "Epoch: 003/010 | Batch 050/938 | Cost: 0.0440\n",
      "Epoch: 003/010 | Batch 100/938 | Cost: 0.0876\n",
      "Epoch: 003/010 | Batch 150/938 | Cost: 0.1927\n",
      "Epoch: 003/010 | Batch 200/938 | Cost: 0.1592\n",
      "Epoch: 003/010 | Batch 250/938 | Cost: 0.1010\n",
      "Epoch: 003/010 | Batch 300/938 | Cost: 0.1311\n",
      "Epoch: 003/010 | Batch 350/938 | Cost: 0.2633\n",
      "Epoch: 003/010 | Batch 400/938 | Cost: 0.2272\n",
      "Epoch: 003/010 | Batch 450/938 | Cost: 0.2475\n",
      "Epoch: 003/010 | Batch 500/938 | Cost: 0.1742\n",
      "Epoch: 003/010 | Batch 550/938 | Cost: 0.0937\n",
      "Epoch: 003/010 | Batch 600/938 | Cost: 0.2019\n",
      "Epoch: 003/010 | Batch 650/938 | Cost: 0.1171\n",
      "Epoch: 003/010 | Batch 700/938 | Cost: 0.1200\n",
      "Epoch: 003/010 | Batch 750/938 | Cost: 0.1760\n",
      "Epoch: 003/010 | Batch 800/938 | Cost: 0.0595\n",
      "Epoch: 003/010 | Batch 850/938 | Cost: 0.1174\n",
      "Epoch: 003/010 | Batch 900/938 | Cost: 0.1585\n",
      "Epoch: 003/010 training accuracy: 96.93%\n",
      "Time elapsed: 0.63 min\n",
      "Epoch: 004/010 | Batch 000/938 | Cost: 0.0601\n",
      "Epoch: 004/010 | Batch 050/938 | Cost: 0.0644\n",
      "Epoch: 004/010 | Batch 100/938 | Cost: 0.1762\n",
      "Epoch: 004/010 | Batch 150/938 | Cost: 0.2237\n",
      "Epoch: 004/010 | Batch 200/938 | Cost: 0.0488\n",
      "Epoch: 004/010 | Batch 250/938 | Cost: 0.0304\n",
      "Epoch: 004/010 | Batch 300/938 | Cost: 0.1097\n",
      "Epoch: 004/010 | Batch 350/938 | Cost: 0.1154\n",
      "Epoch: 004/010 | Batch 400/938 | Cost: 0.2170\n",
      "Epoch: 004/010 | Batch 450/938 | Cost: 0.0193\n",
      "Epoch: 004/010 | Batch 500/938 | Cost: 0.0457\n",
      "Epoch: 004/010 | Batch 550/938 | Cost: 0.0845\n",
      "Epoch: 004/010 | Batch 600/938 | Cost: 0.0482\n",
      "Epoch: 004/010 | Batch 650/938 | Cost: 0.0267\n",
      "Epoch: 004/010 | Batch 700/938 | Cost: 0.1988\n",
      "Epoch: 004/010 | Batch 750/938 | Cost: 0.0505\n",
      "Epoch: 004/010 | Batch 800/938 | Cost: 0.2189\n",
      "Epoch: 004/010 | Batch 850/938 | Cost: 0.0378\n",
      "Epoch: 004/010 | Batch 900/938 | Cost: 0.1241\n",
      "Epoch: 004/010 training accuracy: 97.57%\n",
      "Time elapsed: 0.84 min\n",
      "Epoch: 005/010 | Batch 000/938 | Cost: 0.0834\n",
      "Epoch: 005/010 | Batch 050/938 | Cost: 0.1044\n",
      "Epoch: 005/010 | Batch 100/938 | Cost: 0.0275\n",
      "Epoch: 005/010 | Batch 150/938 | Cost: 0.0497\n",
      "Epoch: 005/010 | Batch 200/938 | Cost: 0.1309\n",
      "Epoch: 005/010 | Batch 250/938 | Cost: 0.1043\n",
      "Epoch: 005/010 | Batch 300/938 | Cost: 0.0290\n",
      "Epoch: 005/010 | Batch 350/938 | Cost: 0.0926\n",
      "Epoch: 005/010 | Batch 400/938 | Cost: 0.0186\n",
      "Epoch: 005/010 | Batch 450/938 | Cost: 0.1377\n",
      "Epoch: 005/010 | Batch 500/938 | Cost: 0.0227\n",
      "Epoch: 005/010 | Batch 550/938 | Cost: 0.0664\n",
      "Epoch: 005/010 | Batch 600/938 | Cost: 0.0825\n",
      "Epoch: 005/010 | Batch 650/938 | Cost: 0.0761\n",
      "Epoch: 005/010 | Batch 700/938 | Cost: 0.0321\n",
      "Epoch: 005/010 | Batch 750/938 | Cost: 0.0946\n",
      "Epoch: 005/010 | Batch 800/938 | Cost: 0.0219\n",
      "Epoch: 005/010 | Batch 850/938 | Cost: 0.0287\n",
      "Epoch: 005/010 | Batch 900/938 | Cost: 0.1176\n",
      "Epoch: 005/010 training accuracy: 97.35%\n",
      "Time elapsed: 1.04 min\n",
      "Epoch: 006/010 | Batch 000/938 | Cost: 0.0235\n",
      "Epoch: 006/010 | Batch 050/938 | Cost: 0.0705\n",
      "Epoch: 006/010 | Batch 100/938 | Cost: 0.0529\n",
      "Epoch: 006/010 | Batch 150/938 | Cost: 0.0226\n",
      "Epoch: 006/010 | Batch 200/938 | Cost: 0.0313\n",
      "Epoch: 006/010 | Batch 250/938 | Cost: 0.0317\n",
      "Epoch: 006/010 | Batch 300/938 | Cost: 0.0393\n",
      "Epoch: 006/010 | Batch 350/938 | Cost: 0.0317\n",
      "Epoch: 006/010 | Batch 400/938 | Cost: 0.0341\n",
      "Epoch: 006/010 | Batch 450/938 | Cost: 0.0347\n",
      "Epoch: 006/010 | Batch 500/938 | Cost: 0.0658\n",
      "Epoch: 006/010 | Batch 550/938 | Cost: 0.0223\n",
      "Epoch: 006/010 | Batch 600/938 | Cost: 0.1233\n",
      "Epoch: 006/010 | Batch 650/938 | Cost: 0.0330\n",
      "Epoch: 006/010 | Batch 700/938 | Cost: 0.0122\n",
      "Epoch: 006/010 | Batch 750/938 | Cost: 0.0728\n",
      "Epoch: 006/010 | Batch 800/938 | Cost: 0.1590\n",
      "Epoch: 006/010 | Batch 850/938 | Cost: 0.0982\n",
      "Epoch: 006/010 | Batch 900/938 | Cost: 0.0298\n",
      "Epoch: 006/010 training accuracy: 98.46%\n",
      "Time elapsed: 1.25 min\n",
      "Epoch: 007/010 | Batch 000/938 | Cost: 0.0400\n",
      "Epoch: 007/010 | Batch 050/938 | Cost: 0.1568\n",
      "Epoch: 007/010 | Batch 100/938 | Cost: 0.0724\n",
      "Epoch: 007/010 | Batch 150/938 | Cost: 0.2265\n",
      "Epoch: 007/010 | Batch 200/938 | Cost: 0.0221\n",
      "Epoch: 007/010 | Batch 250/938 | Cost: 0.0142\n",
      "Epoch: 007/010 | Batch 300/938 | Cost: 0.0837\n",
      "Epoch: 007/010 | Batch 350/938 | Cost: 0.1274\n",
      "Epoch: 007/010 | Batch 400/938 | Cost: 0.0372\n",
      "Epoch: 007/010 | Batch 450/938 | Cost: 0.0902\n",
      "Epoch: 007/010 | Batch 500/938 | Cost: 0.0803\n",
      "Epoch: 007/010 | Batch 550/938 | Cost: 0.0229\n",
      "Epoch: 007/010 | Batch 600/938 | Cost: 0.0453\n",
      "Epoch: 007/010 | Batch 650/938 | Cost: 0.0195\n",
      "Epoch: 007/010 | Batch 700/938 | Cost: 0.1837\n",
      "Epoch: 007/010 | Batch 750/938 | Cost: 0.0499\n",
      "Epoch: 007/010 | Batch 800/938 | Cost: 0.0406\n",
      "Epoch: 007/010 | Batch 850/938 | Cost: 0.0500\n",
      "Epoch: 007/010 | Batch 900/938 | Cost: 0.0717\n",
      "Epoch: 007/010 training accuracy: 98.65%\n",
      "Time elapsed: 1.46 min\n",
      "Epoch: 008/010 | Batch 000/938 | Cost: 0.0179\n",
      "Epoch: 008/010 | Batch 050/938 | Cost: 0.0589\n",
      "Epoch: 008/010 | Batch 100/938 | Cost: 0.0335\n",
      "Epoch: 008/010 | Batch 150/938 | Cost: 0.0211\n",
      "Epoch: 008/010 | Batch 200/938 | Cost: 0.0545\n",
      "Epoch: 008/010 | Batch 250/938 | Cost: 0.0219\n",
      "Epoch: 008/010 | Batch 300/938 | Cost: 0.0395\n",
      "Epoch: 008/010 | Batch 350/938 | Cost: 0.1509\n",
      "Epoch: 008/010 | Batch 400/938 | Cost: 0.1123\n",
      "Epoch: 008/010 | Batch 450/938 | Cost: 0.0262\n",
      "Epoch: 008/010 | Batch 500/938 | Cost: 0.1050\n",
      "Epoch: 008/010 | Batch 550/938 | Cost: 0.0804\n",
      "Epoch: 008/010 | Batch 600/938 | Cost: 0.0080\n",
      "Epoch: 008/010 | Batch 650/938 | Cost: 0.0510\n",
      "Epoch: 008/010 | Batch 700/938 | Cost: 0.0269\n",
      "Epoch: 008/010 | Batch 750/938 | Cost: 0.0175\n",
      "Epoch: 008/010 | Batch 800/938 | Cost: 0.0942\n",
      "Epoch: 008/010 | Batch 850/938 | Cost: 0.0452\n",
      "Epoch: 008/010 | Batch 900/938 | Cost: 0.0179\n",
      "Epoch: 008/010 training accuracy: 98.79%\n",
      "Time elapsed: 1.67 min\n",
      "Epoch: 009/010 | Batch 000/938 | Cost: 0.0745\n",
      "Epoch: 009/010 | Batch 050/938 | Cost: 0.0414\n",
      "Epoch: 009/010 | Batch 100/938 | Cost: 0.1068\n",
      "Epoch: 009/010 | Batch 150/938 | Cost: 0.0644\n",
      "Epoch: 009/010 | Batch 200/938 | Cost: 0.0175\n",
      "Epoch: 009/010 | Batch 250/938 | Cost: 0.0171\n",
      "Epoch: 009/010 | Batch 300/938 | Cost: 0.0626\n",
      "Epoch: 009/010 | Batch 350/938 | Cost: 0.1016\n",
      "Epoch: 009/010 | Batch 400/938 | Cost: 0.0094\n",
      "Epoch: 009/010 | Batch 450/938 | Cost: 0.0147\n",
      "Epoch: 009/010 | Batch 500/938 | Cost: 0.0191\n",
      "Epoch: 009/010 | Batch 550/938 | Cost: 0.0259\n",
      "Epoch: 009/010 | Batch 600/938 | Cost: 0.0519\n",
      "Epoch: 009/010 | Batch 650/938 | Cost: 0.0041\n",
      "Epoch: 009/010 | Batch 700/938 | Cost: 0.0307\n",
      "Epoch: 009/010 | Batch 750/938 | Cost: 0.0121\n",
      "Epoch: 009/010 | Batch 800/938 | Cost: 0.0308\n",
      "Epoch: 009/010 | Batch 850/938 | Cost: 0.0094\n",
      "Epoch: 009/010 | Batch 900/938 | Cost: 0.0168\n",
      "Epoch: 009/010 training accuracy: 99.11%\n",
      "Time elapsed: 1.87 min\n",
      "Epoch: 010/010 | Batch 000/938 | Cost: 0.0393\n",
      "Epoch: 010/010 | Batch 050/938 | Cost: 0.0156\n",
      "Epoch: 010/010 | Batch 100/938 | Cost: 0.0285\n",
      "Epoch: 010/010 | Batch 150/938 | Cost: 0.0080\n",
      "Epoch: 010/010 | Batch 200/938 | Cost: 0.0148\n",
      "Epoch: 010/010 | Batch 250/938 | Cost: 0.0367\n",
      "Epoch: 010/010 | Batch 300/938 | Cost: 0.0511\n",
      "Epoch: 010/010 | Batch 350/938 | Cost: 0.0230\n",
      "Epoch: 010/010 | Batch 400/938 | Cost: 0.0563\n",
      "Epoch: 010/010 | Batch 450/938 | Cost: 0.0435\n",
      "Epoch: 010/010 | Batch 500/938 | Cost: 0.0626\n",
      "Epoch: 010/010 | Batch 550/938 | Cost: 0.0835\n",
      "Epoch: 010/010 | Batch 600/938 | Cost: 0.1073\n",
      "Epoch: 010/010 | Batch 650/938 | Cost: 0.0313\n",
      "Epoch: 010/010 | Batch 700/938 | Cost: 0.0279\n",
      "Epoch: 010/010 | Batch 750/938 | Cost: 0.0343\n",
      "Epoch: 010/010 | Batch 800/938 | Cost: 0.1145\n",
      "Epoch: 010/010 | Batch 850/938 | Cost: 0.0085\n",
      "Epoch: 010/010 | Batch 900/938 | Cost: 0.0067\n",
      "Epoch: 010/010 training accuracy: 99.33%\n",
      "Time elapsed: 2.08 min\n",
      "Total Training Time: 2.08 min\n"
     ]
    }
   ],
   "source": [
    "def compute_accuracy(net, data_loader, device=None):\n",
    "    \"\"\"Return the classification accuracy of ``net`` on ``data_loader`` in percent.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    net : torch.nn.Module\n",
    "        Model whose forward pass returns a pair; the predicted class is the\n",
    "        argmax of the second element along dim 1.\n",
    "    data_loader : iterable\n",
    "        Yields ``(features, targets)`` batches. Each feature batch is\n",
    "        flattened to ``(batch_size, -1)`` before it is passed to ``net``.\n",
    "    device : torch.device, optional\n",
    "        Device the batches are moved to. Defaults to the device of the first\n",
    "        parameter of ``net`` (CPU if ``net`` has no parameters).\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    torch.Tensor\n",
    "        Zero-dimensional float tensor with the accuracy in percent.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If ``data_loader`` yields no examples.\n",
    "\n",
    "    Notes\n",
    "    -----\n",
    "    ``net`` is switched to eval mode and is left in that mode.\n",
    "    \"\"\"\n",
    "    if device is None:\n",
    "        first_param = next(net.parameters(), None)\n",
    "        device = first_param.device if first_param is not None else torch.device('cpu')\n",
    "\n",
    "    net.eval()\n",
    "    correct_pred, num_examples = 0, 0\n",
    "    with torch.no_grad():\n",
    "        for features, targets in data_loader:\n",
    "            # Keep the batch dimension and flatten everything else.\n",
    "            features = features.view(features.size(0), -1).to(device)\n",
    "            targets = targets.to(device)\n",
    "            logits, probas = net(features)\n",
    "            _, predicted_labels = torch.max(probas, 1)\n",
    "            num_examples += targets.size(0)\n",
    "            correct_pred += (predicted_labels == targets).sum()\n",
    "\n",
    "    # Without this guard an empty loader fails on `int.float()`.\n",
    "    if num_examples == 0:\n",
    "        raise ValueError('data_loader yielded no examples; accuracy is undefined')\n",
    "    return correct_pred.float()/num_examples * 100\n",
    "    \n",
    "\n",
    "start_time = time.time()\n",
    "num_batches = len(train_loader)\n",
    "\n",
    "for epoch in range(num_epochs):\n",
    "    # compute_accuracy() puts the model in eval mode, so re-enable training mode.\n",
    "    model.train()\n",
    "    for batch_idx, (features, targets) in enumerate(train_loader):\n",
    "\n",
    "        # Flatten every 28x28 image and move the batch to the device.\n",
    "        # `features` keeps its name: a later cell reuses the last batch.\n",
    "        features = features.view(-1, 28*28).to(device)\n",
    "        targets = targets.to(device)\n",
    "\n",
    "        ### FORWARD AND BACK PROP\n",
    "        logits = model(features)[0]\n",
    "        cost = F.cross_entropy(logits, targets)\n",
    "        optimizer.zero_grad()\n",
    "        cost.backward()\n",
    "\n",
    "        ### UPDATE MODEL PARAMETERS\n",
    "        optimizer.step()\n",
    "\n",
    "        ### LOGGING (every 50th batch)\n",
    "        if batch_idx % 50 == 0:\n",
    "            print ('Epoch: %03d/%03d | Batch %03d/%03d | Cost: %.4f'\n",
    "                   % (epoch+1, num_epochs, batch_idx, num_batches, cost))\n",
    "\n",
    "    # Accuracy on the training set after the epoch, without gradient tracking.\n",
    "    with torch.no_grad():\n",
    "        train_acc = compute_accuracy(model, train_loader)\n",
    "    print('Epoch: %03d/%03d training accuracy: %.2f%%' % (\n",
    "          epoch+1, num_epochs, train_acc))\n",
    "\n",
    "    elapsed_min = (time.time() - start_time)/60\n",
    "    print('Time elapsed: %.2f min' % elapsed_min)\n",
    "\n",
    "print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test accuracy: 97.82%\n"
     ]
    }
   ],
   "source": [
    "# Accuracy of the trained model on the test loader.\n",
    "test_acc = compute_accuracy(model, test_loader)\n",
    "print('Test accuracy: %.2f%%' % test_acc)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Accessing Intermediate Results via Hooks"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One disadvantage of the Sequential wrapper is that we cannot readily access (or \"print\") intermediate values. However, we can use custom hooks for that. For instance, the order of operations in our Sequential wrapper is as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Sequential(\n",
       "  (0): Linear(in_features=784, out_features=128, bias=True)\n",
       "  (1): ReLU(inplace)\n",
       "  (2): Linear(in_features=128, out_features=256, bias=True)\n",
       "  (3): ReLU(inplace)\n",
       "  (4): Linear(in_features=256, out_features=10, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the Sequential container; its repr lists the contained modules by index.\n",
    "model.net"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If we want to get the output of the 2nd linear layer (`model.net[2]`) during the forward pass, we can register a forward hook on it as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch.utils.hooks.RemovableHandle at 0x7f659c6685c0>"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "outputs = []\n",
    "\n",
    "def hook(module, input, output):\n",
    "    \"\"\"Forward hook that appends a snapshot of ``output`` to ``outputs``.\n",
    "\n",
    "    The tensor is detached and cloned before it is stored. The ReLU modules\n",
    "    in ``model.net`` run in place, so a stored reference to a Linear layer's\n",
    "    output would be overwritten by the following ReLU and show post-ReLU\n",
    "    values instead of the layer's own output.\n",
    "    ``module`` and ``input`` are not used.\n",
    "    \"\"\"\n",
    "    outputs.append(output.detach().clone())\n",
    "\n",
    "# Attach `hook` to the module at index 2 of model.net; the returned handle\n",
    "# is the cell's displayed result.\n",
    "model.net[2].register_forward_hook(hook)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, if we call the model on some inputs, it will save the intermediate results in the \"outputs\" list:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[tensor([[0.5341, 1.0513, 2.3542,  ..., 0.0000, 0.0000, 0.0000],\n",
      "        [0.0000, 0.6676, 0.6620,  ..., 0.0000, 0.0000, 2.4056],\n",
      "        [1.1520, 0.0000, 0.0000,  ..., 2.5860, 0.8992, 0.9642],\n",
      "        ...,\n",
      "        [0.0000, 0.1076, 0.0000,  ..., 1.8367, 0.0000, 2.5203],\n",
      "        [0.5415, 0.0000, 0.0000,  ..., 2.7968, 0.8244, 1.6335],\n",
      "        [1.0710, 0.9805, 3.0103,  ..., 0.0000, 0.0000, 0.0000]],\n",
      "       device='cuda:3', grad_fn=<ThresholdBackward1>)]\n"
     ]
    }
   ],
   "source": [
    "# `features` is the last flattened training batch left over from the training loop.\n",
    "# The forward pass triggers the registered hook; the model's return value is discarded.\n",
    "_ = model(features)\n",
    "\n",
    "print(outputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "numpy       1.15.4\n",
      "torch       1.0.0\n",
      "\n"
     ]
    }
   ],
   "source": [
    "%watermark -iv"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.1"
  },
  "toc": {
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
